import torch
import torch.nn as nn

import utils.graph_lib as Graphs
from models.radd import RADD
from models.uvit import UViT

class GraphPreconditioner(nn.Module):
    """Wrap a denoising network with condition shifting, CFG dropout and guidance.

    Condition labels are shifted by +1 before reaching ``net`` so that the value
    0 can act as the "empty" (unconditional) token for classifier-free guidance.
    """

    def __init__(self, net, graph: "Graphs.Absorbing", cfg_p=.1) -> None:
        """
        Args:
            net: Network called as ``net(x, shifted_cond)``.
            graph: Diffusion graph; stored on the instance but not used by the
                methods of this class.
            cfg_p: Probability, in the closed interval [0, 1], of replacing a
                sample's condition with the empty token while training.
        """
        super().__init__()
        self.net = net
        self.graph = graph
        assert 0 <= cfg_p <= 1, 'CFG Probability has to be in [0,1]'
        self.cfg_prob = cfg_p

    @staticmethod
    def _esigm1_log(sigma_int, dtype):
        """Return ``log(exp(sigma_int) - 1)`` cast to ``dtype`` and shaped ``(-1, 1, 1)``.

        ``expm1`` is used for sigma < 0.5, where ``exp(s) - 1`` loses precision
        to cancellation.
        """
        esigm1 = torch.where(sigma_int < 0.5, torch.expm1(sigma_int), sigma_int.exp() - 1)
        return esigm1.log().to(dtype).view(-1, 1, 1)

    def forward_with_cfg(self, x, sigma_int, cond, cfg_scale, force_condition_class=None):
        """Blend a conditional and a reference pass with guidance weight ``cfg_scale``.

        The reference pass uses label -1 (shifted by ``forward`` to the empty
        token 0) unless ``force_condition_class`` is given, in which case that
        class is used instead.

        Returns:
            ``cfg_scale * cond_out + (1 - cfg_scale) * ref_out``. All but the last
            entry of the final dimension are renormalised with ``log_softmax``.
        """
        if force_condition_class is None:
            cfg_cond = -torch.ones_like(cond)
        else:
            cfg_cond = torch.ones_like(cond) * force_condition_class
        # NOTE(review): these calls go through forward() with cfg_scale=1, so in
        # training mode they are also subject to random condition dropout.
        cond_score = self(x, sigma_int, cond, return_score=False)
        uncond_score = self(x, sigma_int, cfg_cond, return_score=False)
        score_w = cfg_scale * cond_score + (1 - cfg_scale) * uncond_score
        # The last entry of the final dim is excluded from renormalisation;
        # presumably the absorbing/mask state — confirm against the graph.
        score_w[:, :, :-1] = score_w[:, :, :-1].log_softmax(dim=-1)
        return score_w

    def forward(self, x, sigma_int, cond, cfg_scale=1., return_score=False, force_condition_class=None):
        """Run the wrapped network on ``x`` conditioned on ``cond``.

        Args:
            x: Network input, passed unchanged to ``net``.
            sigma_int: Per-sample noise level; only used when ``return_score``
                is True, where it is reshaped to ``(-1, 1, 1)``.
            cond: Condition labels; shifted by +1 before reaching ``net``.
            cfg_scale: Guidance weight. ``1.`` runs a single conditional pass
                (with random condition dropout in training mode); any other
                value delegates to ``forward_with_cfg``.
            return_score: If True, subtract ``log(exp(sigma_int) - 1)`` from the
                output.
            force_condition_class: Forwarded to ``forward_with_cfg``.
        """
        if cfg_scale != 1.:
            out = self.forward_with_cfg(x, sigma_int, cond, cfg_scale, force_condition_class)
        else:
            shift_cond = cond + 1  # 0 is reserved as the CFG empty token
            if self.training:
                # One uniform draw per sample, shaped to broadcast over cond's
                # trailing dims and created on cond's device.
                u = torch.rand((cond.shape[0], *([1] * (cond.dim() - 1))), device=cond.device)
                # torch.where broadcasts the per-sample mask; boolean-mask
                # assignment does not, and failed for multi-dim conditions.
                shift_cond = torch.where(u < self.cfg_prob, torch.zeros_like(shift_cond), shift_cond)
            out = self.net(x, shift_cond)
        if return_score:
            # Cast to the output's dtype (not the input's): x may hold integer
            # tokens, which would truncate the correction term.
            out = out - self._esigm1_log(sigma_int, out.dtype)
        return out

def get_model(name, vocab_size, context_len, net_opts):
    """Instantiate the backbone network selected by ``name``.

    Args:
        name: ``'radd'`` or ``'uvit'``.
        vocab_size: Vocabulary size passed to the backbone constructor.
        context_len: Sequence length; only forwarded to ``RADD`` (``UViT`` is
            constructed without it).
        net_opts: Extra keyword arguments forwarded to the constructor.

    Raises:
        ValueError: if ``name`` is not a known backbone.
    """
    if name == 'radd':
        return RADD(vocab_size, context_len, **net_opts)
    if name == 'uvit':
        return UViT(vocab_size, **net_opts)
    # Fail loudly instead of silently returning None to the caller.
    raise ValueError(f"Unknown model name {name!r}; expected 'radd' or 'uvit'")

def get_preconditioned_model(net, graph, cfg_prob=.1):
    """Wrap ``net`` in a ``GraphPreconditioner`` bound to ``graph``.

    ``cfg_prob`` is forwarded as the preconditioner's ``cfg_p`` — the probability
    of dropping the condition during training.
    """
    preconditioner = GraphPreconditioner(net=net, graph=graph, cfg_p=cfg_prob)
    return preconditioner